Demonstrate dcp checkpoint save resume with context parallel #1421
Conversation
Important: Review skipped. Auto reviews are disabled on this repository; check the settings in the CodeRabbit UI. You can disable this status message in your review settings.

📝 Walkthrough

This pull request introduces state management and worker configuration access to context-parallel dataloader wrappers across four recipe/model implementations, updates dataset parameter naming for clarity, and modifies training scripts to support rank-aware dataloader initialization with revised checkpoint group references.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Rank0 as Rank 0 Process
    participant RankN as Non-Zero Rank<br/>Process
    participant Wrapper as ContextParallel<br/>DataLoaderWrapper
    participant THD as THD DataLoader
    participant Collator as DataCollatorFor<br/>ContextParallel
    participant Ckpt as Checkpoint<br/>Manager

    Note over Rank0,RankN: Dataloader Initialization
    Rank0->>Rank0: Create dataset<br/>with pad_sequences...
    Rank0->>THD: Initialize THD<br/>dataloader
    Rank0->>Collator: Attach collator
    RankN->>RankN: Set dataloader to None

    Note over Rank0,RankN: Wrapper Creation
    Rank0->>Wrapper: Wrap THD +<br/>collator
    RankN->>Wrapper: Wrap None with<br/>CP mesh

    Note over Rank0,RankN: State Management
    Rank0->>Wrapper: state_dict()
    Wrapper->>THD: Delegate to<br/>underlying loader
    RankN->>Wrapper: state_dict()
    Wrapper->>RankN: Return {}

    Note over Rank0,RankN: Checkpoint Operations
    Rank0->>Ckpt: Save with<br/>cp_dp_mesh group
    Rank0->>Ckpt: Load with<br/>cp_dp_mesh group
    RankN->>Ckpt: Participate in<br/>group operations
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✅ Actions performed: Review triggered.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/llama3/collator.py`:
- Around lines 431-437: The load_state_dict method on the dataloader wrapper currently accesses state_dict["dataloader"] directly, which raises a KeyError if the checkpoint omitted dataloader state. Update load_state_dict to first check cp_rank == 0, then confirm hasattr(self.dataloader, "load_state_dict") and that "dataloader" is present and not None in the passed state_dict (e.g., use "dataloader" in state_dict and state_dict.get("dataloader") is not None) before calling self.dataloader.load_state_dict(...); otherwise log a warning that dataloader state was absent and skip loading, to keep resume robust (see the sketch below).
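A minimal sketch of the guarded method the comment asks for, assuming the wrapper stores cp_rank and dataloader attributes as the surrounding review comments describe; the real class in collator.py may differ in details:

```python
import logging

logger = logging.getLogger(__name__)


def load_state_dict(self, state_dict: dict) -> None:
    """Restore dataloader state on CP rank 0; skip gracefully when state is absent.

    Intended as a method of the ContextParallelDataLoaderWrapper class.
    """
    if self.cp_rank != 0:
        # Non-zero CP ranks hold no dataloader, so there is nothing to restore.
        return
    if not hasattr(self.dataloader, "load_state_dict"):
        logger.warning("Underlying dataloader does not support load_state_dict; skipping.")
        return
    if "dataloader" not in state_dict or state_dict.get("dataloader") is None:
        logger.warning("Checkpoint contains no dataloader state; resuming without it.")
        return
    self.dataloader.load_state_dict(state_dict["dataloader"])
```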
In `@bionemo-recipes/recipes/llama3_native_te/dataset.py`:
- Around line 139-141: The docstring for the parameter
pad_sequences_to_be_divisible_by incorrectly states "Default: 16" while the
function signature (pad_sequences_to_be_divisible_by: int | None = None)
defaults to None; update the docstring in the dataset function to reflect the
actual default (e.g., "Default: None") or remove the default mention entirely so
the parameter description matches the signature.
In
`@bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py`:
- Around line 612-626: Phase 2 is composing the wrong config (uses
config_name="L0_sanity") so the checkpoint-specific settings are not applied;
update the compose call that creates phase2_config (inside the
initialize_config_dir block) to use config_name="L0_sanity_cp" instead, leaving
the same overrides (checkpoint.* and dataset.*) intact so the resume run picks
up the checkpoint-specific configuration.
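For reference, a hedged sketch of what the corrected phase-2 compose call could look like; recipe_path comes from the conftest fixture mentioned below, and the override keys shown are placeholders rather than the test's actual checkpoint.* and dataset.* values:

```python
from hydra import compose, initialize_config_dir

# The config directory layout and override keys below are assumptions for illustration.
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
    phase2_config = compose(
        config_name="L0_sanity_cp",  # was "L0_sanity"; the resume run must use the CP config
        overrides=[
            f"checkpoint.ckpt_dir={checkpoint_dir}",  # hypothetical override key
            "checkpoint.resume_from_checkpoint=true",  # hypothetical override key
        ],
    )
```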
🧹 Nitpick comments (2)
bionemo-recipes/recipes/llama3_native_te/collator.py (1)

418-451: LGTM! Implementation matches the established pattern across collator modules. The state management methods are correctly implemented and consistent with the other collator files (esm2_native_te, models/esm2, models/llama3).

Consider consolidating the ContextParallelDataLoaderWrapper class into a shared module to reduce code duplication across the four collator files. The module docstring already notes this code "should eventually get moved to a separate package"; this PR would be a good opportunity to track that as a follow-up.

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (1)

798-825: Consider extracting a small helper for CP THD setup to avoid drift. This block duplicates the test_cp_dataloader construction; a tiny helper would keep the CP padding + collator wiring consistent.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- bionemo-recipes/models/esm2/src/esm/collator.py
- bionemo-recipes/models/llama3/collator.py
- bionemo-recipes/recipes/esm2_native_te/collator.py
- bionemo-recipes/recipes/llama3_native_te/collator.py
- bionemo-recipes/recipes/llama3_native_te/dataset.py
- bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml
- bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
- bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (.cursorrules)
**/*.py: Fix Python linter errors immediately using Ruff for linting and formatting (configured with line-length: 119 in pyproject.toml), and verify all auto-fixes are appropriate
Ensure all Python files follow Google-style docstrings (pydocstyle convention)
Follow import sorting configuration as per isort with 2 lines after imports
Use Pyright for type checking as configured in pyproject.toml
Files:
- bionemo-recipes/recipes/llama3_native_te/collator.py
- bionemo-recipes/models/llama3/collator.py
- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
- bionemo-recipes/recipes/llama3_native_te/dataset.py
- bionemo-recipes/recipes/esm2_native_te/collator.py
- bionemo-recipes/models/esm2/src/esm/collator.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
{**/*test*.py,**/__init__.py}
📄 CodeRabbit inference engine (.cursorrules)
Ensure test files and __init__.py files respect relaxed linting rules as configured in pyproject.toml
Files:
- bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py
🧠 Learnings (1)
📚 Learning: 2025-08-28T16:40:04.315Z
Learnt from: pstjohn
Repo: NVIDIA/bionemo-framework PR: 1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.315Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
Applied to files:
- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
- bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py
🧬 Code graph analysis (4)
bionemo-recipes/recipes/llama3_native_te/collator.py (2)
- bionemo-recipes/models/llama3/collator.py (3): state_dict (418-429), load_state_dict (431-442), num_workers (445-450)
- bionemo-recipes/recipes/esm2_native_te/collator.py (3): state_dict (418-429), load_state_dict (431-442), num_workers (445-450)

bionemo-recipes/models/esm2/src/esm/collator.py (3)
- bionemo-recipes/models/llama3/collator.py (3): state_dict (418-429), load_state_dict (431-442), num_workers (445-450)
- bionemo-recipes/recipes/esm2_native_te/collator.py (3): state_dict (418-429), load_state_dict (431-442), num_workers (445-450)
- bionemo-recipes/recipes/llama3_native_te/collator.py (3): state_dict (418-429), load_state_dict (431-442), num_workers (445-450)

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (1)
- bionemo-recipes/recipes/llama3_native_te/collator.py (3): ContextParallelDataLoaderWrapper (335-450), DataCollatorForContextParallel (279-332), num_workers (445-450)

bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py (2)
- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py (1): main (44-233)
- bionemo-recipes/recipes/llama3_native_te/tests/conftest.py (1): recipe_path (27-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: unit-tests (models/esm2)
- GitHub Check: unit-tests (models/llama3)
- GitHub Check: unit-tests (recipes/llama3_native_te)
- GitHub Check: unit-tests (recipes/esm2_native_te)
- GitHub Check: pre-commit
🔇 Additional comments (16)
bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_cp.yaml (1)
7-11: LGTM! Configuration changes align well with the BSHD context-parallel setup. The changes correctly configure BSHD format with causal masking, and the inline comments helpfully document the THD alternatives. The use_sequence_packing: false setting is appropriate since the BSHD format doesn't require sequence packing.

bionemo-recipes/recipes/esm2_native_te/collator.py (1)
418-451: LGTM! State management and worker introspection properly implemented for the CP-wrapped dataloader. The implementation correctly:
- Restricts state operations to CP rank 0 (where the dataloader exists)
- Gracefully handles dataloaders that don't support state_dict/load_state_dict
- Returns sensible defaults (empty dict, 0 workers) for non-rank-0 processes
This enables checkpoint save/resume functionality for context-parallel training as intended by the PR.
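A rough skeleton of the rank-aware defaults this comment describes; the class and attribute names below are assumptions drawn from the review text, not the recipe's exact implementation:

```python
class _WrapperStateSketch:
    """Illustrative stand-in for ContextParallelDataLoaderWrapper's state handling."""

    def __init__(self, dataloader, cp_rank: int):
        self.dataloader = dataloader  # None on non-zero CP ranks
        self.cp_rank = cp_rank

    def state_dict(self) -> dict:
        # Only CP rank 0 owns a stateful dataloader; other ranks contribute nothing.
        if self.cp_rank == 0 and hasattr(self.dataloader, "state_dict"):
            return {"dataloader": self.dataloader.state_dict()}
        return {}

    @property
    def num_workers(self) -> int:
        # Report the underlying loader's worker count, or 0 where no loader exists.
        if self.cp_rank == 0 and self.dataloader is not None:
            return getattr(self.dataloader, "num_workers", 0)
        return 0
```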
bionemo-recipes/models/esm2/src/esm/collator.py (1)
418-451: LGTM! Consistent implementation with other collator modules. State management and worker introspection methods are correctly implemented, maintaining consistency with the other ContextParallelDataLoaderWrapper implementations across the codebase.

bionemo-recipes/recipes/llama3_native_te/dataset.py (2)
120-121: LGTM! Parameter rename improves semantic clarity. The renamed parameter pad_sequences_to_be_divisible_by better conveys its purpose for context-parallelism alignment. The mapping to HuggingFace's pad_to_multiple_of parameter is correctly maintained internally.

Also applies to: 168-171
220-220: LGTM! THD dataloader correctly wires the new parameter. The pad_sequences_to_be_divisible_by parameter is properly passed to DataCollatorWithFlattening, enabling context-parallel compatible padding in the THD flow.

Also applies to: 268-271
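The padding semantics are simple enough to spell out. A small, self-contained illustration of what "pad to be divisible by N" means; the 2 * cp_size convention mentioned in the comment is the usual context-parallel choice, not necessarily what this PR hard-codes:

```python
def padded_length(seq_len: int, divisor: int | None) -> int:
    """Length after padding a sequence up to the next multiple of `divisor`."""
    if divisor is None:
        return seq_len  # no CP-specific padding requested
    return ((seq_len + divisor - 1) // divisor) * divisor


# With cp_size = 4, sequences are typically padded to a multiple of 2 * cp_size = 8
# so each rank receives balanced head/tail chunks.
assert padded_length(1021, 8) == 1024
assert padded_length(1021, None) == 1021
```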
bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py (4)
31-32: Import updates look good.
137-146: Process group wiring for CP/DP mesh looks correct.
197-209: Checkpoint save now aligned to CP/DP mesh group (see the sketch after this list).
113-133: No action required. Both create_bshd_dataloader and create_thd_dataloader explicitly accept the pad_sequences_to_be_divisible_by parameter (declared as int | None = None in both function signatures). The setdefault call is safe for both dataloader paths.

Likely an incorrect or invalid review comment.
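Regarding the "197-209: Checkpoint save now aligned to CP/DP mesh group" comment above, a hedged sketch of passing a combined group to torch.distributed.checkpoint; cp_dp_group, model_state, optim_state, and ckpt_dir are placeholders, not the recipe's actual variable names:

```python
import torch.distributed.checkpoint as dcp

# `cp_dp_group` stands in for whatever ProcessGroup the recipe derives from its
# combined CP/DP mesh (e.g. via DeviceMesh.get_group()); the state_dict keys are
# likewise illustrative.
dcp.save(
    state_dict={"model": model_state, "optimizer": optim_state},
    checkpoint_id=str(ckpt_dir / "step_1000"),
    process_group=cp_dp_group,
)
```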
bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (2)
194-213: Nice coverage for the BSHD CP path.
218-234: THD CP test config updates look consistent.

bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py (2)
41-41: Import looks good.
714-725: The concern about rank 1 dataloader files may be unfounded.
save_dataloader() calls torch.save() unconditionally on all ranks (line 413), even when ContextParallelDataLoaderWrapper.state_dict() returns {} on non-zero CP ranks. The function always adds num_workers and num_ranks metadata before saving, ensuring dataloader files are created for all ranks, including rank 1. The test assertions should succeed.

Likely an incorrect or invalid review comment.
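An illustrative sketch of the per-rank save behavior described here; the real save_dataloader in the recipe may use a different signature and file layout, but the key point is that every rank writes a file even when its state dict is empty:

```python
import os

import torch


def save_dataloader(dataloader, ckpt_dir: str, rank: int, num_ranks: int) -> None:
    """Write this rank's dataloader state plus metadata, unconditionally."""
    state = dataloader.state_dict()  # {} on non-zero CP ranks
    state["num_workers"] = dataloader.num_workers
    state["num_ranks"] = num_ranks
    torch.save(state, os.path.join(ckpt_dir, f"dataloader_rank_{rank}.pt"))
```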
bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml (1)
25-25: Config addition looks good.

bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py (2)
27-28: Imports align with the new CP wrapper flow.
707-728: Rank‑0‑only THD construction + CP wrapper looks correct.
Force-pushed from d214d6e to 80e29e5.
jomitchellnv left a comment:
LGTM
Makes it easier to run BSHD context parallel runs in the llama3 recipe for local testing, and adds checkpoint save/resume checks to the llama3 recipe.

BIO-8

Summary by CodeRabbit

Release Notes

New Features
- Added checkpoint save/restore capabilities and worker configuration querying for context-parallel data loading
- Introduced a new configuration parameter for sequence padding control

Improvements
- Enhanced distributed training support with improved checkpoint integration for context parallelism

Tests
- Added comprehensive integration tests for distributed checkpointing with context parallelism
- Added multi-GPU training tests with different attention format configurations